feat(mlx): add handler for aten.roll#19038
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19038
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below:
|
This PR needs a
|
Maps torch.roll to mlx::core::roll via a new RollNode. Adds the schema table, the custom handler for the (shifts, dims) args, the exec_roll runtime, and test cases covering 1D, 2D, multi-axis, negative shifts, and negative dims. Flat roll (dims=[]) is explicitly NotImplementedError for now; all known use cases (Swin Transformer shift-window attention) pass dims. Fixes pytorch#18919
53b77ef to
726c721
Compare
Summary
Adds an MLX delegate handler for
aten.roll, mappingtorch.rollontomlx::core::rollvia a newRollNodein the schema. Replaces the default decomposition (index_select + arange + cat) with a single native kernel — needed by Swin Transformer's shift-window attention.Flat roll (
dims=[]) raisesNotImplementedErrorfor now; no known consumer needs it yet.Generated files (
MLXLoader.*,schema_generated.h,mlx_graph_schema.py,_generated_serializers.py,_generated_inspector.py,_generated/) are regenerated fromschema.fbsbybackends/mlx/CMakeLists.txtat build time and are deliberately not committed.Fixes #18919.
Test plan
python backends/mlx/serialization/generate.py— regenerates cleanly withRollNodein all expected outputs.lintrunner --skip MYPY --paths-cmd 'git diff --name-only upstream/main'— no issues.run_all_tests -k rollnot run locally (no executorch build on this machine); relying on CI. Happy to push fixes if it finds anything.